{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3c3d93a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/envs/pt2/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Using sep_token, but it is not set yet.\n",
      "Using cls_token, but it is not set yet.\n",
      "Using mask_token, but it is not set yet.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "GPT2TokenizerFast(name_or_path='tokenizer/gpt2-large', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '!'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={\n",
       "\t50256: AddedToken(\"<|endoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "}"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import random\n",
    "import torch\n",
    "\n",
    "checkpoint = 110\n",
    "device = 'cuda'\n",
    "dtype = torch.float16\n",
    "only_test = True\n",
    "\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained('tokenizer/gpt2-large')\n",
    "tokenizer.pad_token_id = 0\n",
    "\n",
    "tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "58af6f4a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(DatasetDict({\n",
       "     train: Dataset({\n",
       "         features: ['prompt', 'chosen', 'rejected'],\n",
       "         num_rows: 200\n",
       "     })\n",
       "     test: Dataset({\n",
       "         features: ['prompt', 'chosen', 'rejected'],\n",
       "         num_rows: 1500\n",
       "     })\n",
       " }),\n",
       " {'prompt': 'context:CREATE TABLE table_26400075_2 (weeks_in_top_10 VARCHAR, artist VARCHAR) question:How many weeks in the top-10 did Beats International have? answer:',\n",
       "  'chosen': 'SELECT weeks_in_top_10 FROM table_26400075_2 WHERE artist = \"Beats International\"',\n",
       "  'rejected': ''})"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_from_disk\n",
    "\n",
    "dataset = load_from_disk('dataset/b-mc2/sql-create-context')['train']\n",
    "\n",
    "\n",
    "def f(data):\n",
    "    question = 'context:%s question:%s answer:' % (data['context'],\n",
    "                                                   data['question'])\n",
    "    answer = data['answer']\n",
    "    return {'question': question, 'answer': answer}\n",
    "\n",
    "\n",
    "dataset = dataset.map(f, remove_columns=['context'])\n",
    "\n",
    "\n",
    "def f(data):\n",
    "    question = len(tokenizer.encode(data['question']))\n",
    "    answer = len(tokenizer.encode(data['answer']))\n",
    "    return 25 <= question <= 65 and 10 <= answer <= 35\n",
    "\n",
    "\n",
    "dataset = dataset.filter(f)\n",
    "\n",
    "\n",
    "def f(data):\n",
    "    return {\n",
    "        'prompt': data['question'],\n",
    "        'chosen': data['answer'],\n",
    "        'rejected': ''\n",
    "    }\n",
    "\n",
    "\n",
    "dataset = dataset.map(f, remove_columns=['question', 'answer'])\n",
    "dataset = dataset.train_test_split(test_size=1500)\n",
    "\n",
    "if only_test:\n",
    "    dataset['train'] = dataset['train'].select(range(200))\n",
    "\n",
    "dataset, dataset['train'][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "12276e41",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'model/dpo_110.model'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import AutoModelForCausalLM\n",
    "\n",
    "path = 'model/gpt2-large'\n",
    "if checkpoint != -1:\n",
    "    path = 'model/dpo_%d.model' % checkpoint\n",
    "\n",
    "model_dpo = AutoModelForCausalLM.from_pretrained(path).to(device)\n",
    "if not only_test:\n",
    "    model_dpo_ref = AutoModelForCausalLM.from_pretrained(path).to(device)\n",
    "\n",
    "path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "08c4ec71",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 200/200 [00:02<00:00, 68.13 examples/s]\n",
      "Map: 100%|██████████| 1500/1500 [00:17<00:00, 87.49 examples/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(DatasetDict({\n",
       "     train: Dataset({\n",
       "         features: ['prompt', 'chosen', 'rejected'],\n",
       "         num_rows: 200\n",
       "     })\n",
       "     test: Dataset({\n",
       "         features: ['prompt', 'chosen', 'rejected'],\n",
       "         num_rows: 1500\n",
       "     })\n",
       " }),\n",
       " {'prompt': 'context:CREATE TABLE table_26400075_2 (weeks_in_top_10 VARCHAR, artist VARCHAR) question:How many weeks in the top-10 did Beats International have? answer:',\n",
       "  'chosen': 'SELECT weeks_in_top_10 FROM table_26400075_2 WHERE artist = \"Beats International\"',\n",
       "  'rejected': 'SELECT weeks_in_top_10 FROM table_26400075_2 WHERE artist = \"Beatles International\"'})"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#重新生成数据集中的rejected字段\n",
    "def remake_dataset():\n",
    "    global dataset\n",
    "    tokenizer.padding_side = 'left'\n",
    "    model_dpo.to(dtype)\n",
    "\n",
    "    def f(data):\n",
    "        token = tokenizer(data['prompt'],\n",
    "                          return_tensors='pt',\n",
    "                          padding=True,\n",
    "                          truncation=True).to(device)\n",
    "\n",
    "        out = model_dpo.generate(**token,\n",
    "                                 min_length=-1,\n",
    "                                 top_k=1,\n",
    "                                 top_p=1.0,\n",
    "                                 do_sample=True,\n",
    "                                 pad_token_id=tokenizer.pad_token_id,\n",
    "                                 max_new_tokens=35,\n",
    "                                 eos_token_id=tokenizer.eos_token_id)\n",
    "\n",
    "        for i in range(len(out)):\n",
    "            lens = len(token['input_ids'][i])\n",
    "            rejected = out[i, lens:]\n",
    "\n",
    "            if tokenizer.eos_token_id in rejected:\n",
    "                lens = rejected.tolist().index(tokenizer.eos_token_id) + 1\n",
    "                rejected = rejected[:lens]\n",
    "\n",
    "            rejected = rejected[:35]\n",
    "            rejected = tokenizer.decode(rejected, skip_special_tokens=True)\n",
    "\n",
    "            if rejected == data['chosen'][i]:\n",
    "                rejected = ''\n",
    "\n",
    "            data['rejected'][i] = rejected\n",
    "\n",
    "        return data\n",
    "\n",
    "    dataset = dataset.map(f, batched=True, batch_size=128, num_proc=1)\n",
    "\n",
    "    tokenizer.padding_side = 'right'\n",
    "    model_dpo.to(torch.float32)\n",
    "\n",
    "\n",
    "if only_test or checkpoint != -1:\n",
    "    remake_dataset()\n",
    "\n",
    "dataset, dataset['train'][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "03f8f012",
   "metadata": {},
   "outputs": [],
   "source": [
    "#重载模型\n",
    "def reload_model(epoch):\n",
    "    global model_dpo\n",
    "    global model_dpo_ref\n",
    "\n",
    "    path = 'model/dpo_%d.model' % epoch\n",
    "    model_dpo.save_pretrained(path)\n",
    "    model_dpo = AutoModelForCausalLM.from_pretrained(path).to(device)\n",
    "    model_dpo_ref = AutoModelForCausalLM.from_pretrained(path).to(device)\n",
    "\n",
    "\n",
    "# reload_model(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "16be7af3",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/envs/pt2/lib/python3.10/site-packages/trl/trainer/ppo_config.py:141: UserWarning: The `optimize_cuda_cache` arguement will be deprecated soon, please use `optimize_device_cache` instead.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from transformers import TrainingArguments, TrainerCallback\n",
    "from trl import DPOTrainer\n",
    "\n",
    "\n",
    "def retrain():\n",
    "\n",
    "    class MyCallback(TrainerCallback):\n",
    "\n",
    "        def on_step_end(self, args, state, control, **kwargs):\n",
    "            if state.global_step % 250 == 0:\n",
    "                print(state.global_step)\n",
    "                return\n",
    "\n",
    "                data = random.choice(dataset['test'])\n",
    "                input_ids = tokenizer.encode(data['prompt'],\n",
    "                                             return_tensors='pt').to(device)\n",
    "\n",
    "                out = model_dpo.generate(input_ids,\n",
    "                                         min_length=-1,\n",
    "                                         top_k=1,\n",
    "                                         top_p=1.0,\n",
    "                                         do_sample=True,\n",
    "                                         pad_token_id=tokenizer.pad_token_id,\n",
    "                                         max_new_tokens=35,\n",
    "                                         eos_token_id=tokenizer.eos_token_id)\n",
    "\n",
    "                print(tokenizer.decode(out[0]))\n",
    "                print('=================')\n",
    "                print(data['chosen'])\n",
    "                print(data['rejected'])\n",
    "                print('=================')\n",
    "\n",
    "    args = TrainingArguments(output_dir='output_dir',\n",
    "                             learning_rate=1e-5,\n",
    "                             per_device_train_batch_size=4,\n",
    "                             max_steps=5000,\n",
    "                             evaluation_strategy='no',\n",
    "                             report_to='none',\n",
    "                             save_strategy='no')\n",
    "\n",
    "    trainer = DPOTrainer(model_dpo,\n",
    "                         model_dpo_ref,\n",
    "                         args=args,\n",
    "                         beta=0.1,\n",
    "                         train_dataset=dataset['train'],\n",
    "                         tokenizer=tokenizer,\n",
    "                         max_length=100,\n",
    "                         max_target_length=100,\n",
    "                         max_prompt_length=100,\n",
    "                         callbacks=[MyCallback()])\n",
    "\n",
    "    trainer.train()\n",
    "\n",
    "\n",
    "# retrain()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e3a5bd33",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "prompt -> context:CREATE TABLE table_name_11 (division INTEGER, reg_season VARCHAR) question:Who was the lowest division in the 7th season? answer:\n",
      "chosen -> SELECT MIN(division) FROM table_name_11 WHERE reg_season = \"7th\"\n",
      "rejected -> \n",
      "=========\n",
      "prompt -> context:CREATE TABLE cinema (LOCATION VARCHAR) question:Show each location and the number of cinemas there. answer:\n",
      "chosen -> SELECT LOCATION, COUNT(*) FROM cinema GROUP BY LOCATION\n",
      "rejected -> \n",
      "=========\n",
      "prompt -> context:CREATE TABLE Sections (section_name VARCHAR, section_description VARCHAR) question:What are the names and descriptions of all the sections? answer:\n",
      "chosen -> SELECT section_name, section_description FROM Sections\n",
      "rejected -> \n",
      "=========\n",
      "prompt -> context:CREATE TABLE table_name_41 (venue VARCHAR, date VARCHAR) question:What is the venue of the game that was played on 23 October 1966? answer:\n",
      "chosen -> SELECT venue FROM table_name_41 WHERE date = \"23 october 1966\"\n",
      "rejected -> \n",
      "=========\n",
      "prompt -> context:CREATE TABLE table_26077092_7 (pick__number INTEGER, player VARCHAR) question:What was the pick number for Andrew Quarless?  answer:\n",
      "chosen -> SELECT MAX(pick__number) FROM table_26077092_7 WHERE player = \"Andrew Quarless\"\n",
      "rejected -> \n",
      "=========\n",
      "prompt -> context:CREATE TABLE table_name_27 (film VARCHAR, year VARCHAR) question:What film was made in 1999? answer:\n",
      "chosen -> SELECT film FROM table_name_27 WHERE year = \"1999\"\n",
      "rejected -> \n",
      "=========\n",
      "prompt -> context:CREATE TABLE table_name_82 (origin VARCHAR, owner VARCHAR) question:What is the origin for the item with an owner of Hunan Broadcasting System (HBS)? answer:\n",
      "chosen -> SELECT origin FROM table_name_82 WHERE owner = \"hunan broadcasting system (hbs)\"\n",
      "rejected -> \n",
      "=========\n",
      "prompt -> context:CREATE TABLE table_name_7 (region VARCHAR, date VARCHAR) question:From which region is the album with release date of 19 June 2007? answer:\n",
      "chosen -> SELECT region FROM table_name_7 WHERE date = \"19 june 2007\"\n",
      "rejected -> \n",
      "=========\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(0.848, 0.8573333333333333)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def test():\n",
    "    sample = random.choices(dataset['test'], k=8)\n",
    "    for i in sample:\n",
    "        for k, v in i.items():\n",
    "            print(k, '->', v)\n",
    "        print('=========')\n",
    "\n",
    "    def correct(data, lower):\n",
    "        rejected = data['rejected']\n",
    "        chosen = data['chosen']\n",
    "\n",
    "        if rejected == '':\n",
    "            return True\n",
    "\n",
    "        if lower:\n",
    "            chosen = chosen.lower().replace('\"', '\\'')\n",
    "            rejected = rejected.lower().replace('\"', '\\'')\n",
    "\n",
    "        return chosen == rejected\n",
    "\n",
    "    def accuracy(lower):\n",
    "        return sum([correct(i, lower)\n",
    "                    for i in dataset['test']]) / len(dataset['test'])\n",
    "\n",
    "    return accuracy(False), accuracy(True)\n",
    "\n",
    "\n",
    "test()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6dec7111",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "if not only_test:\n",
    "    for epoch in range(checkpoint + 1, 100):\n",
    "        retrain()\n",
    "        reload_model(epoch)\n",
    "        remake_dataset()\n",
    "\n",
    "        print('epoch', epoch, 'test:', test())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pt2]",
   "language": "python",
   "name": "conda-env-pt2-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
